import argparse
import numpy as np
import torch
import gym
from vae import VAE
#from coolname import generate_slug
import utils
import d4rl

# Command-line options. The environment / file names are built below as
# 'antmaze-' + dataset + '-v2', so --dataset must be an AntMaze variant.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0,
                    help='seed for the NumPy / torch RNGs')
# AntMaze variants in D4RL: umaze, umaze-diverse, medium-play, medium-diverse,
# large-play, large-diverse.
# NOTE(review): the default 'medium' yields 'antmaze-medium-v2', which is not
# one of the standard D4RL AntMaze IDs -- confirm it is registered locally.
parser.add_argument('--dataset', type=str, default='medium',
                    help="AntMaze dataset variant (env id is 'antmaze-<dataset>-v2')")
# Not read anywhere in this script; kept so existing command lines still parse.
parser.add_argument('--num_iters', type=int, default=int(1e5),
                    help='unused by this script')
args = parser.parse_args()


# Apply --seed: previously parsed but never used, so the random sampling in
# the evaluation loop below was not reproducible.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Use the GPU when present; fall back to CPU instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load data
env_name = f'antmaze-{args.dataset}-v2'
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

print(state_dim, action_dim, max_action)
latent_dim = action_dim * 2

# Same paths as before: <DATA_DIR>/antmaze-<dataset>-v2.hdf5 (full dataset)
# and <DATA_DIR>/antmaze-<dataset>-v2-finetune.hdf5 (fine-tune / "expert" set).
DATA_DIR = '../implicit_q_learning/models/datasets_cvae'
MODEL_PATH = ('../implicit_q_learning/models/cvae_1e5_varloss/'
              f'vae_model_antmaze_{args.dataset}.pt')

replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata(f'{DATA_DIR}/{env_name}.hdf5')
replay_selfbuffer = utils.ReplayBuffer(state_dim, action_dim)
replay_selfbuffer.convert_selfdata(f'{DATA_DIR}/{env_name}-finetune.hdf5')

states = replay_buffer.state
actions = replay_buffer.action

vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=512).to(device)
# map_location lets a checkpoint saved on GPU load on whatever device we chose.
vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Inference only; eval() matters if VAE contains dropout/batch-norm (not visible here).
vae.eval()

total_size = states.shape[0]

# Centre of the expert latent distribution: the element-wise average, over all
# fine-tune transitions, of the 2nd and 3rd outputs of vae(...) (presumably the
# encoder mean and std -- confirm against vae.VAE.forward).
# NOTE(review): averaging per-sample stds is a heuristic; it is not the
# moment-matched std of the mixture of per-sample Gaussians.
states_expert = replay_selfbuffer.state
actions_expert = replay_selfbuffer.action
with torch.no_grad():  # no autograd graph needed just to take means
    train_states = torch.from_numpy(states_expert).to(device)
    train_actions = torch.from_numpy(actions_expert).to(device)
    _, mean_all, std_all = vae(train_states, train_actions)
    mean = torch.mean(mean_all, 0)
    std = torch.mean(std_all, 0)


def gaussian_kl(mean_p, std_p, mean_q, std_q):
    """Return KL(N(mean_p, std_p^2) || N(mean_q, std_q^2)), averaged over elements.

    Per element the closed form is
        log(std_q / std_p) + (std_p^2 + (mean_p - mean_q)^2) / (2 * std_q^2) - 1/2,
    written below as 0.5 * ((std_p/std_q)^2 + log((std_q/std_p)^2)
    + ((mean_p - mean_q)/std_q)^2 - 1). Arguments broadcast against each other;
    the result is the mean over every element of the broadcast tensor (i.e.
    per-dimension average, not the sum over latent dimensions).
    """
    return 0.5 * ((std_p / std_q).pow(2)
                  + torch.log((std_q / std_p).pow(2))
                  + ((mean_p - mean_q) / std_q).pow(2)
                  - 1).mean()


# Draw single transitions from the fine-tune buffer, encode each with the VAE
# and measure KL(expert centre || sample). Each value is printed; the final
# line is how many samples were at least KL_THRESHOLD away from the centre.
# (To score the full dataset instead, sample from `replay_buffer`.)
NUM_SAMPLES = 1000
KL_THRESHOLD = 1

cn = 0  # count of samples with KL >= KL_THRESHOLD
with torch.no_grad():  # pure inference: skip building autograd graphs
    for step in range(NUM_SAMPLES):
        idx_self = np.random.choice(replay_selfbuffer.size, 1, replace=False)
        states_2 = replay_selfbuffer.state[idx_self]
        actions_2 = replay_selfbuffer.action[idx_self]
        train_states = torch.from_numpy(states_2).to(device)
        train_actions = torch.from_numpy(actions_2).to(device)
        _, mean1, std1 = vae(train_states, train_actions)
        KL_loss = gaussian_kl(mean, std, mean1, std1)
        kl_value = KL_loss.item()  # single host sync per step
        if kl_value >= KL_THRESHOLD:
            cn += 1
        print(kl_value)
print(cn)

